{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Private Predictions with Syft Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Private Prediction using Syft Keras - Serving (Client)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congratulations! After training your model with normal Keras and securing it with Syft Keras, you are ready to request some private predictions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.datasets import mnist\n",
    "\n",
    "import syft as sy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we preprocess our MNIST data. This is identical to how we preprocessed during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# input image dimensions\n",
    "img_rows, img_cols = 28, 28\n",
    "\n",
    "# the data, split between train and test sets\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
    "x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
    "input_shape = (img_rows, img_cols, 1)\n",
    "\n",
    "x_train = x_train.astype('float32')\n",
    "x_test = x_test.astype('float32')\n",
    "x_train /= 255\n",
    "x_test /= 255"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connect to model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before querying the model, you just have to connect to it. To do so, you can create a client. Then define the exact same three TFEWorkers (`alice`, `bob`, and `carol`) and cluster. Finally call `connect_to_model`.  This creates a TFE queueing server on the client side that connects to the queueing server set up by `model.serve()` in **Part 13b**. The queue will be responsible for secret sharing the plaintext data before submitting the shares in a prediction request."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes = 10\n",
    "input_shape = (1, 28, 28, 1)\n",
    "output_shape = (1, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tf_encrypted:Starting session on target 'grpc://localhost:4000' using config graph_options {\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "client = sy.TFEWorker()\n",
    "\n",
    "alice = sy.TFEWorker(host='localhost:4000')\n",
    "bob = sy.TFEWorker(host='localhost:4001')\n",
    "carol = sy.TFEWorker(host='localhost:4002')\n",
    "cluster = sy.TFECluster(alice, bob, carol)\n",
    "\n",
    "client.connect_to_model(input_shape, output_shape, cluster)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Query model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You are ready to get some private predictions! Calling `query_model` will insert the `image` into the queue created above, secret share the data locally, and submit the shares to the model server in **Part 13b**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# User inputs\n",
    "num_tests = 3\n",
    "images, expected_labels = x_test[:num_tests], y_test[:num_tests]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The image had label 7 and was correctly classified as 7\n",
      "The image had label 2 and was correctly classified as 2\n",
      "The image had label 1 and was correctly classified as 1\n"
     ]
    }
   ],
   "source": [
    "for image, expected_label in zip(images, expected_labels):\n",
    "\n",
    "    res = client.query_model(image.reshape(1, 28, 28, 1))\n",
    "    predicted_label = np.argmax(res)\n",
    "\n",
    "    print(\"The image had label {} and was {} classified as {}\".format(\n",
    "        expected_label,\n",
    "        \"correctly\" if expected_label == predicted_label else \"wrongly\",\n",
    "        predicted_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is great. You are able to classify these three images correctly! But what's special about these predictions is that you haven't revealed any private information to get this service. The model host never saw your input data or your predictions, and you never downloaded the model. You were able to get private predictions on encrypted data with an encrypted model!\n",
    "\n",
    "Before we rush off to apply this in our own apps, let's quickly go back to **Part 13b** to clean up our served model!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
